{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 蒸馏通道裁剪模型示例\n",
    "本示例介绍使用更高精度的[YOLOv3-ResNet34](../../configs/yolov3_r34.yml)模型蒸馏经通道裁剪的[YOLOv3-MobileNet](../../configs/yolov3_mobilenet_v1.yml)模型。脚本可参照蒸馏脚本[distill.py](../distillation/distill.py)和通道裁剪脚本[prune.py](../prune/prune.py)简单修改得到，蒸馏过程采用细粒度损失来蒸馏YOLOv3输出层特征图。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "切换到PaddleDetection根目录，设置环境变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workplace/PaddleDetection\n"
     ]
    }
   ],
   "source": [
    "% cd ../.."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入依赖包，注意须同时导入蒸馏和通道裁剪的相关接口"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "from paddleslim.dist.single_distiller import merge, l2_loss\n",
    "from paddleslim.prune import Pruner\n",
    "from paddleslim.analysis import flops\n",
    "\n",
    "from paddle import fluid\n",
    "from ppdet.core.workspace import load_config, merge_config, create\n",
    "from ppdet.data.reader import create_reader\n",
    "from ppdet.utils.eval_utils import parse_fetches, eval_results, eval_run\n",
    "from ppdet.utils.stats import TrainingStats\n",
    "from ppdet.utils.cli import ArgsParser\n",
    "from ppdet.utils.check import check_gpu\n",
    "import ppdet.utils.checkpoint as checkpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义细粒度的蒸馏损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_distill(split_output_names, weight):\n",
    "    \"\"\"\n",
    "    Add fine grained distillation losses.\n",
    "    Each loss is composed by distill_reg_loss, distill_cls_loss and\n",
    "    distill_obj_loss\n",
    "    \"\"\"\n",
    "    student_var = []\n",
    "    for name in split_output_names:\n",
    "        student_var.append(fluid.default_main_program().global_block().var(\n",
    "            name))\n",
    "    s_x0, s_y0, s_w0, s_h0, s_obj0, s_cls0 = student_var[0:6]\n",
    "    s_x1, s_y1, s_w1, s_h1, s_obj1, s_cls1 = student_var[6:12]\n",
    "    s_x2, s_y2, s_w2, s_h2, s_obj2, s_cls2 = student_var[12:18]\n",
    "    teacher_var = []\n",
    "    for name in split_output_names:\n",
    "        teacher_var.append(fluid.default_main_program().global_block().var(\n",
    "            'teacher_' + name))\n",
    "    t_x0, t_y0, t_w0, t_h0, t_obj0, t_cls0 = teacher_var[0:6]\n",
    "    t_x1, t_y1, t_w1, t_h1, t_obj1, t_cls1 = teacher_var[6:12]\n",
    "    t_x2, t_y2, t_w2, t_h2, t_obj2, t_cls2 = teacher_var[12:18]\n",
    "\n",
    "    def obj_weighted_reg(sx, sy, sw, sh, tx, ty, tw, th, tobj):\n",
    "        loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(\n",
    "            sx, fluid.layers.sigmoid(tx))\n",
    "        loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(\n",
    "            sy, fluid.layers.sigmoid(ty))\n",
    "        loss_w = fluid.layers.abs(sw - tw)\n",
    "        loss_h = fluid.layers.abs(sh - th)\n",
    "        loss = fluid.layers.sum([loss_x, loss_y, loss_w, loss_h])\n",
    "        weighted_loss = fluid.layers.reduce_mean(loss *\n",
    "                                                 fluid.layers.sigmoid(tobj))\n",
    "        return weighted_loss\n",
    "\n",
    "    def obj_weighted_cls(scls, tcls, tobj):\n",
    "        loss = fluid.layers.sigmoid_cross_entropy_with_logits(\n",
    "            scls, fluid.layers.sigmoid(tcls))\n",
    "        weighted_loss = fluid.layers.reduce_mean(\n",
    "            fluid.layers.elementwise_mul(\n",
    "                loss, fluid.layers.sigmoid(tobj), axis=0))\n",
    "        return weighted_loss\n",
    "\n",
    "    def obj_loss(sobj, tobj):\n",
    "        obj_mask = fluid.layers.cast(tobj > 0., dtype=\"float32\")\n",
    "        obj_mask.stop_gradient = True\n",
    "        loss = fluid.layers.reduce_mean(\n",
    "            fluid.layers.sigmoid_cross_entropy_with_logits(sobj, obj_mask))\n",
    "        return loss\n",
    "\n",
    "    distill_reg_loss0 = obj_weighted_reg(s_x0, s_y0, s_w0, s_h0, t_x0, t_y0,\n",
    "                                         t_w0, t_h0, t_obj0)\n",
    "    distill_reg_loss1 = obj_weighted_reg(s_x1, s_y1, s_w1, s_h1, t_x1, t_y1,\n",
    "                                         t_w1, t_h1, t_obj1)\n",
    "    distill_reg_loss2 = obj_weighted_reg(s_x2, s_y2, s_w2, s_h2, t_x2, t_y2,\n",
    "                                         t_w2, t_h2, t_obj2)\n",
    "    distill_reg_loss = fluid.layers.sum(\n",
    "        [distill_reg_loss0, distill_reg_loss1, distill_reg_loss2])\n",
    "\n",
    "    distill_cls_loss0 = obj_weighted_cls(s_cls0, t_cls0, t_obj0)\n",
    "    distill_cls_loss1 = obj_weighted_cls(s_cls1, t_cls1, t_obj1)\n",
    "    distill_cls_loss2 = obj_weighted_cls(s_cls2, t_cls2, t_obj2)\n",
    "    distill_cls_loss = fluid.layers.sum(\n",
    "        [distill_cls_loss0, distill_cls_loss1, distill_cls_loss2])\n",
    "\n",
    "    distill_obj_loss0 = obj_loss(s_obj0, t_obj0)\n",
    "    distill_obj_loss1 = obj_loss(s_obj1, t_obj1)\n",
    "    distill_obj_loss2 = obj_loss(s_obj2, t_obj2)\n",
    "    distill_obj_loss = fluid.layers.sum(\n",
    "        [distill_obj_loss0, distill_obj_loss1, distill_obj_loss2])\n",
    "    loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss) * weight\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取配置文件，设置use_fined_grained_loss=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = load_config(\"./configs/yolov3_mobilenet_v1.yml\")\n",
    "merge_config({'use_fine_grained_loss': True})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建执行器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "devices_num = fluid.core.get_cuda_device_count()\n",
    "place = fluid.CUDAPlace(0)\n",
    "# devices_num = int(os.environ.get('CPU_NUM', 1))\n",
    "# place = fluid.CPUPlace()\n",
    "exe = fluid.Executor(place)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造训练模型和reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading annotations into memory...\n",
      "Done (t=24.56s)\n",
      "creating index...\n",
      "index created!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:02:05,272-WARNING: Found an invalid bbox in annotations: im_id: 550395, area: 0.0 x1: 9.98, y1: 188.56, x2: 14.52, y2: 188.56.\n",
      "2020-02-07 08:02:17,043-WARNING: Found an invalid bbox in annotations: im_id: 200365, area: 0.0 x1: 296.65, y1: 388.33, x2: 296.68, y2: 388.33.\n",
      "2020-02-07 08:02:19,653-INFO: 118287 samples in file dataset/coco/annotations/instances_train2017.json\n",
      "2020-02-07 08:02:25,768-INFO: places would be ommited when DataLoader is not iterable\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<paddle.fluid.reader.GeneratorLoader at 0x7f5071d63890>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "main_arch = cfg.architecture\n",
    "# build program\n",
    "model = create(main_arch)\n",
    "inputs_def = cfg['TrainReader']['inputs_def']\n",
    "train_feed_vars, train_loader = model.build_inputs(**inputs_def)\n",
    "train_fetches = model.train(train_feed_vars)\n",
    "loss = train_fetches['loss']\n",
    "\n",
    "start_iter = 0\n",
    "train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg)\n",
    "train_loader.set_sample_list_generator(train_reader, place)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造评估模型和reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading annotations into memory...\n",
      "Done (t=0.84s)\n",
      "creating index...\n",
      "index created!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:02:32,708-INFO: 5000 samples in file dataset/coco/annotations/instances_val2017.json\n",
      "2020-02-07 08:02:32,805-INFO: places would be ommited when DataLoader is not iterable\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<paddle.fluid.reader.GeneratorLoader at 0x7f4fd29fd090>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_prog = fluid.Program()\n",
    "with fluid.program_guard(eval_prog, fluid.default_startup_program()):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = create(main_arch)\n",
    "        inputs_def = cfg['EvalReader']['inputs_def']\n",
    "        test_feed_vars, eval_loader = model.build_inputs(**inputs_def)\n",
    "        fetches = model.eval(test_feed_vars)\n",
    "eval_prog = eval_prog.clone(True)\n",
    "\n",
    "eval_reader = create_reader(cfg.EvalReader)\n",
    "eval_loader.set_sample_list_generator(eval_reader, place)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造teacher模型并导入权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:02:38,933-INFO: Found /root/.cache/paddle/weights/yolov3_r34\n",
      "2020-02-07 08:02:38,935-INFO: Loading parameters from /root/.cache/paddle/weights/yolov3_r34...\n"
     ]
    }
   ],
   "source": [
    "teacher_cfg = load_config(\"./configs/yolov3_r34.yml\")\n",
    "merge_config({'use_fine_grained_loss': True})\n",
    "teacher_arch = teacher_cfg.architecture\n",
    "teacher_program = fluid.Program()\n",
    "teacher_startup_program = fluid.Program()\n",
    "\n",
    "with fluid.program_guard(teacher_program, teacher_startup_program):\n",
    "    with fluid.unique_name.guard():\n",
    "        teacher_feed_vars = OrderedDict()\n",
    "        for name, var in train_feed_vars.items():\n",
    "            teacher_feed_vars[name] = teacher_program.global_block(\n",
    "            )._clone_variable(\n",
    "                var, force_persistable=False)\n",
    "        model = create(teacher_arch)\n",
    "        train_fetches = model.train(teacher_feed_vars)\n",
    "        teacher_loss = train_fetches['loss']\n",
    "\n",
    "exe.run(teacher_startup_program)\n",
    "checkpoint.load_params(exe, teacher_program, \"https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar\")\n",
    "teacher_program = teacher_program.clone(for_test=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "合并program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_name_map = {\n",
    "    'target0': 'target0',\n",
    "    'target1': 'target1',\n",
    "    'target2': 'target2',\n",
    "    'image': 'image',\n",
    "    'gt_bbox': 'gt_bbox',\n",
    "    'gt_class': 'gt_class',\n",
    "    'gt_score': 'gt_score'\n",
    "}\n",
    "merge(teacher_program, fluid.default_main_program(), data_name_map, place)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造蒸馏损失和优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "yolo_output_names = [\n",
    "    'strided_slice_0.tmp_0', 'strided_slice_1.tmp_0',\n",
    "    'strided_slice_2.tmp_0', 'strided_slice_3.tmp_0',\n",
    "    'strided_slice_4.tmp_0', 'transpose_0.tmp_0', 'strided_slice_5.tmp_0',\n",
    "    'strided_slice_6.tmp_0', 'strided_slice_7.tmp_0',\n",
    "    'strided_slice_8.tmp_0', 'strided_slice_9.tmp_0', 'transpose_2.tmp_0',\n",
    "    'strided_slice_10.tmp_0', 'strided_slice_11.tmp_0',\n",
    "    'strided_slice_12.tmp_0', 'strided_slice_13.tmp_0',\n",
    "    'strided_slice_14.tmp_0', 'transpose_4.tmp_0'\n",
    "]\n",
    "    \n",
    "distill_loss = split_distill(yolo_output_names, 1000)\n",
    "loss = distill_loss + loss\n",
    "lr_builder = create('LearningRate')\n",
    "optim_builder = create('OptimizerBuilder')\n",
    "lr = lr_builder()\n",
    "opt = optim_builder(lr)\n",
    "opt.minimize(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入待裁剪模型裁剪前全部权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:03:39,997-INFO: Found /root/.cache/paddle/weights/yolov3_mobilenet_v1\n",
      "2020-02-07 08:03:39,999-INFO: Loading parameters from /root/.cache/paddle/weights/yolov3_mobilenet_v1...\n"
     ]
    }
   ],
   "source": [
    "exe.run(fluid.default_startup_program())\n",
    "checkpoint.load_params(exe, fluid.default_main_program(), \"https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "裁剪训练和评估program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pruned params: ['yolo_block.0.0.0.conv.weights', 'yolo_block.0.0.1.conv.weights', 'yolo_block.0.1.0.conv.weights', 'yolo_block.0.1.1.conv.weights', 'yolo_block.0.2.conv.weights', 'yolo_block.0.tip.conv.weights', 'yolo_block.1.0.0.conv.weights', 'yolo_block.1.0.1.conv.weights', 'yolo_block.1.1.0.conv.weights', 'yolo_block.1.1.1.conv.weights', 'yolo_block.1.2.conv.weights', 'yolo_block.1.tip.conv.weights', 'yolo_block.2.0.0.conv.weights', 'yolo_block.2.0.1.conv.weights', 'yolo_block.2.1.0.conv.weights', 'yolo_block.2.1.1.conv.weights', 'yolo_block.2.2.conv.weights', 'yolo_block.2.tip.conv.weights']\n",
      "pruned ratios: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]\n",
      "FLOPs -0.675602593026; total FLOPs: 24531648.0; pruned FLOPs: 7958003.0\n"
     ]
    }
   ],
   "source": [
    "pruned_params = [\"yolo_block.0.0.0.conv.weights\",                                                                          \n",
    "                \"yolo_block.0.0.1.conv.weights\",                                                                          \n",
    "                \"yolo_block.0.1.0.conv.weights\",                                                                          \n",
    "                \"yolo_block.0.1.1.conv.weights\",                                                                          \n",
    "                \"yolo_block.0.2.conv.weights\",                                                                            \n",
    "                \"yolo_block.0.tip.conv.weights\",                                                                          \n",
    "                \"yolo_block.1.0.0.conv.weights\",                                                                          \n",
    "                \"yolo_block.1.0.1.conv.weights\",                                                                          \n",
    "                \"yolo_block.1.1.0.conv.weights\",                                                                          \n",
    "                \"yolo_block.1.1.1.conv.weights\",                                                                          \n",
    "                \"yolo_block.1.2.conv.weights\",                                                                            \n",
    "                \"yolo_block.1.tip.conv.weights\",                                                                          \n",
    "                \"yolo_block.2.0.0.conv.weights\",                                                                          \n",
    "                \"yolo_block.2.0.1.conv.weights\",                                                                          \n",
    "                \"yolo_block.2.1.0.conv.weights\",                                                                          \n",
    "                \"yolo_block.2.1.1.conv.weights\",                                                                          \n",
    "                \"yolo_block.2.2.conv.weights\",                                                                            \n",
    "                \"yolo_block.2.tip.conv.weights\"]\n",
    "pruned_ratios = [0.5] * 6 + [0.7] * 6 + [0.8] * 6\n",
    "\n",
    "print(\"pruned params: {}\".format(pruned_params))\n",
    "print(\"pruned ratios: {}\".format(pruned_ratios))\n",
    "\n",
    "pruner = Pruner()\n",
    "distill_prog = pruner.prune(\n",
    "    fluid.default_main_program(),\n",
    "    fluid.global_scope(),\n",
    "    params=pruned_params,\n",
    "    ratios=pruned_ratios,\n",
    "    place=place,\n",
    "    only_graph=False)[0]\n",
    "\n",
    "base_flops = flops(eval_prog)\n",
    "eval_prog = pruner.prune(\n",
    "    eval_prog,\n",
    "    fluid.global_scope(),\n",
    "    params=pruned_params,\n",
    "    ratios=pruned_ratios,\n",
    "    place=place,\n",
    "    only_graph=True)[0]\n",
    "pruned_flops = flops(eval_prog)\n",
    "print(\"FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}\".format(float(base_flops - pruned_flops)/base_flops, base_flops, pruned_flops))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "编译训练和评估program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "build_strategy = fluid.BuildStrategy()\n",
    "build_strategy.fuse_all_reduce_ops = False\n",
    "build_strategy.fuse_all_optimizer_ops = False\n",
    "build_strategy.fuse_elewise_add_act_ops = True\n",
    "# only enable sync_bn in multi GPU devices\n",
    "sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'\n",
    "build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \\\n",
    "    and cfg.use_gpu\n",
    "\n",
    "exec_strategy = fluid.ExecutionStrategy()\n",
    "# iteration number when CompiledProgram tries to drop local execution scopes.\n",
    "# Set it to be 1 to save memory usages, so that unused variables in\n",
    "# local execution scopes can be deleted after each iteration.\n",
    "exec_strategy.num_iteration_per_drop_scope = 1\n",
    "\n",
    "parallel_main = fluid.CompiledProgram(distill_prog).with_data_parallel(\n",
    "    loss_name=loss.name,\n",
    "    build_strategy=build_strategy,\n",
    "    exec_strategy=exec_strategy)\n",
    "compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0 lr 0.000000, loss 1224.825684, distill_loss 362.955292, teacher_loss 39.179062\n"
     ]
    }
   ],
   "source": [
    "# parse eval fetches\n",
    "extra_keys = []\n",
    "if cfg.metric == 'COCO':\n",
    "    extra_keys = ['im_info', 'im_id', 'im_shape']\n",
    "if cfg.metric == 'VOC':\n",
    "    extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']\n",
    "eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,\n",
    "                                                     extra_keys)\n",
    "\n",
    "# whether output bbox is normalized in model output layer\n",
    "is_bbox_normalized = False\n",
    "map_type = cfg.map_type if 'map_type' in cfg else '11point'\n",
    "best_box_ap_list = [0.0, 0]  #[map, iter]\n",
    "save_dir = os.path.join(cfg.save_dir, 'yolov3_mobilenet_v1')\n",
    "\n",
    "train_loader.start()\n",
    "for step_id in range(start_iter, cfg.max_iters):\n",
    "    teacher_loss_np, distill_loss_np, loss_np, lr_np = exe.run(\n",
    "        parallel_main,\n",
    "        fetch_list=[\n",
    "            'teacher_' + teacher_loss.name, distill_loss.name, loss.name,\n",
    "            lr.name\n",
    "        ])\n",
    "    if step_id % 20 == 0:\n",
    "        print(\n",
    "            \"step {} lr {:.6f}, loss {:.6f}, distill_loss {:.6f}, teacher_loss {:.6f}\".\n",
    "            format(step_id, lr_np[0], loss_np[0], distill_loss_np[0],\n",
    "                    teacher_loss_np[0]))\n",
    "    if step_id % cfg.snapshot_iter == 0 and step_id != 0 or step_id == cfg.max_iters - 1:\n",
    "        save_name = str(\n",
    "            step_id) if step_id != cfg.max_iters - 1 else \"model_final\"\n",
    "        checkpoint.save(exe,\n",
    "                        distill_prog,\n",
    "                        os.path.join(save_dir, save_name))\n",
    "        # eval\n",
    "        results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys,\n",
    "                                  eval_values, eval_cls)\n",
    "        resolution = None\n",
    "        box_ap_stats = eval_results(results, cfg.metric, cfg.num_classes,\n",
    "                                        resolution, is_bbox_normalized,\n",
    "                                        FLAGS.output_eval, map_type,\n",
    "                                        cfg['EvalReader']['dataset'])\n",
    "\n",
    "        if box_ap_stats[0] > best_box_ap_list[0]:\n",
    "            best_box_ap_list[0] = box_ap_stats[0]\n",
    "            best_box_ap_list[1] = step_id\n",
    "            checkpoint.save(exe,\n",
    "                                distill_prog,\n",
    "                                os.path.join(\"./\", \"best_model\"))\n",
    "        print(\"Best test box ap: {}, in step: {}\".format(\n",
    "                best_box_ap_list[0], best_box_ap_list[1]))\n",
    "    train_loader.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也提供了一键式启动蒸馏通道裁剪模型训练脚本[distill_pruned_model.py](./distill_pruned_model.py)和蒸馏通道裁剪模型库，请参考[蒸馏通道裁剪模型](./README.md)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
